
Conversation

@kuafou (Contributor) commented Oct 29, 2025

Description

This PR fixes a compatibility issue with recent vLLM changes that now require model classes to implement a get_input_embeddings() method.
Without this method, vLLM fails its interface validation during model registration, breaking TPU model integration.

To address this, we add a dummy get_input_embeddings() implementation to the vLLM-compatible wrapper class in
tpu_inference/models/common/model_loader.py.
Like the existing dummy forward() method, it exists only to satisfy vLLM’s type checks and raises
NotImplementedError if invoked, so the JAX model is never initialized during import or introspection.
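
For illustration, here is a minimal sketch of this pattern. The factory name make_vllm_wrapper_class and the unimplemented_forward helper are hypothetical names for this sketch; the PR itself only names unimplemented_get_input_embeddings().

```python
# Minimal sketch of the wrapper pattern, NOT the exact source of
# tpu_inference/models/common/model_loader.py; make_vllm_wrapper_class
# and unimplemented_forward are assumed names for illustration.
import torch.nn as nn


def unimplemented_forward(self, *args, **kwargs):
    # Existing dummy forward(): present only so vLLM's interface checks
    # pass; it is never meant to be called.
    raise NotImplementedError("This wrapper exists only for vLLM registration.")


def unimplemented_get_input_embeddings(self, *args, **kwargs):
    # New dummy get_input_embeddings(): satisfies vLLM's stricter
    # interface validation and raises if actually invoked.
    raise NotImplementedError("get_input_embeddings is not supported here.")


def make_vllm_wrapper_class(name: str) -> type:
    # Dynamically build a PyTorch wrapper class that vLLM's registry will
    # accept, without initializing any JAX state at import time.
    return type(
        name,
        (nn.Module,),
        {
            "forward": unimplemented_forward,
            "get_input_embeddings": unimplemented_get_input_embeddings,
        },
    )
```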

Why this change is needed

  • vLLM recently introduced a strict requirement for model classes to define get_input_embeddings() (link).
  • TPU inference uses a dummy PyTorch wrapper to register JAX models into vLLM’s registry.
  • Since this wrapper lacked get_input_embeddings, vLLM failed model registration checks.

Implementation details

  • Added an unimplemented_get_input_embeddings() dummy function to the wrapper type.
  • Registered it inside the dynamically created wrapper class.
  • Added a test, tests/test_vllm_wrapper.py (sketched after this list), to ensure:
    • The wrapper defines get_input_embeddings().
    • The method raises NotImplementedError.
    • The class passes is_vllm_model() validation.
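
A sketch of what those checks might look like, reusing the hypothetical make_vllm_wrapper_class factory from the sketch above (the actual test code is not shown in this PR):

```python
# Sketch of the checks in tests/test_vllm_wrapper.py, assuming the
# hypothetical make_vllm_wrapper_class factory from the earlier sketch
# is importable here.
import pytest


def test_wrapper_defines_get_input_embeddings():
    wrapper_cls = make_vllm_wrapper_class("DummyJaxModelWrapper")
    assert hasattr(wrapper_cls, "get_input_embeddings")


def test_get_input_embeddings_raises():
    wrapper_cls = make_vllm_wrapper_class("DummyJaxModelWrapper")
    with pytest.raises(NotImplementedError):
        wrapper_cls().get_input_embeddings()

# The third check, that the class passes is_vllm_model() validation,
# would call vLLM's own helper; its import path is not shown in this PR.
```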

Related Issue

Fixes: #951

Tests

pytest -v tests/models/common/test_model_loader.py 

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@kuafou force-pushed the qi/fix-vllm-model-wrapper branch from 79d6f2d to 80ab177 on October 29, 2025 at 23:22
@py4 requested a review from @karan on October 30, 2025 at 19:30
@kuafou force-pushed the qi/fix-vllm-model-wrapper branch 2 times, most recently from 59e1a8e to bc8fe75 on November 5, 2025 at 18:29
@karan (Collaborator) left a comment

Thank you for the PR.

@kuafou (Contributor, Author) commented Nov 23, 2025

Hi @karan,
Looks like everything has passed and the PR is approved.
When you get a chance, could you help merge it? Thanks! 🙏

@kyuyeunk (Collaborator) commented

> Hi @karan,
> Looks like everything has passed and the PR is approved.
> When you get a chance, could you help merge it? Thanks! 🙏

Thank you for the contribution. I'll run the CI manually and will merge if it all passes.

@kyuyeunk (Collaborator) commented

Hmm, seems like the branch is outdated. Can you update it?

@kuafou force-pushed the qi/fix-vllm-model-wrapper branch from bc8fe75 to 85606aa on November 23, 2025 at 08:40
@kuafou requested a review from @vipannalla as a code owner on November 23, 2025 at 08:40
@kyuyeunk (Collaborator) commented

Running CI: https://buildkite.com/tpu-commons/tpu-inference-ci/builds/5856

@kuafou (Contributor, Author) commented Nov 23, 2025

> Hmm, seems like the branch is outdated. Can you update it?

Hi @kyuyeunk, I have updated this PR.

@kyuyeunk (Collaborator) commented

Except for the already-failing tests, I verified that all tests pass.

Merging the PR.

@kyuyeunk merged commit add0b5b into vllm-project:main on Nov 23, 2025
3 checks passed
Successfully merging this pull request may close issue #951: [Bug]: vllm model interface now requires get_input_embeddings